{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "82846047-59b5-4780-b5ca-f933e0fdd145",
   "metadata": {},
   "outputs": [],
   "source": [
    "# https://colab.research.google.com/drive/1vMkH8LmiCCOiCo4KTTEcv-NU8_OGn0ie?usp=sharing#scrollTo=s3ucr4kRmBKk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6b5da7d5-8ed0-47cc-96d6-332d36b49365",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# torch==1.11.0\n",
    "import warprnnt_pytorch\n",
    "import torch\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ac6a2393-f36e-46d0-8b39-09755c7c0026",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device('cuda', 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "de1e256f-9fac-444e-a8a1-7ad11c2a6116",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = torch.FloatTensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.8, 0.1]],\n",
    "                              [[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.1, 0.1],\n",
    "                              [0.7, 0.1, 0.2, 0.1, 0.1]]]])\n",
    "warp_transducer_logits = logits.clone()\n",
    "torchaudio_logits = logits.clone()\n",
    "optimized_transducer_logits = logits.clone()\n",
    "\n",
    "logits_cuda = logits.to(device)\n",
    "warp_transducer_logits_cuda = logits_cuda.clone()\n",
    "torchaudio_logits_cuda = logits_cuda.clone()\n",
    "optimized_transducer_logits_cuda = logits_cuda.clone()\n",
    "\n",
    "targets = torch.tensor([[1, 2]], dtype=torch.int32)\n",
    "logit_lengths = torch.tensor([2], dtype=torch.int32)\n",
    "target_lengths = torch.tensor([2], dtype=torch.int32)\n",
    "\n",
    "targets_cuda = targets.to(device)\n",
    "logit_lengths_cuda = logit_lengths.to(device)\n",
    "target_lengths_cuda = target_lengths.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "cdb3b6fa-5cd2-45d7-92b9-77c65468b33d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1, 2, 3, 5])\n",
      "torch.Size([1])\n",
      "torch.Size([1, 2])\n",
      "torch.Size([1])\n"
     ]
    }
   ],
   "source": [
    "print(logits.shape)\n",
    "print(logit_lengths.shape)\n",
    "print(targets.shape)\n",
    "print(target_lengths.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "4016689c-93ec-4ce6-b062-6049b1bf5f6d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[[[0.1000, 0.6000, 0.1000, 0.1000, 0.1000],\n",
       "          [0.1000, 0.1000, 0.6000, 0.1000, 0.1000],\n",
       "          [0.1000, 0.1000, 0.2000, 0.8000, 0.1000]],\n",
       "\n",
       "         [[0.1000, 0.6000, 0.1000, 0.1000, 0.1000],\n",
       "          [0.1000, 0.1000, 0.2000, 0.1000, 0.1000],\n",
       "          [0.7000, 0.1000, 0.2000, 0.1000, 0.1000]]]], device='cuda:0',\n",
       "       requires_grad=True)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "warp_transducer_logits.requires_grad_(True)\n",
    "torchaudio_logits.requires_grad_(True)\n",
    "optimized_transducer_logits.requires_grad_(True)\n",
    "\n",
    "warp_transducer_logits_cuda.requires_grad_(True)\n",
    "torchaudio_logits_cuda.requires_grad_(True)\n",
    "optimized_transducer_logits_cuda.requires_grad_(True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f267a5e9-b936-4c3c-9472-3691f3da8e0a",
   "metadata": {},
   "source": [
    "# warp_transducer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "80b9d69c-1e29-41e6-bbb8-e73d72439a3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "warp_transducer_cpu_loss = warprnnt_pytorch.rnnt_loss(warp_transducer_logits, \n",
    "                                                      targets, \n",
    "                                                      logit_lengths, \n",
    "                                                      target_lengths,\n",
    "                                                      blank=0,\n",
    "                                                      reduction='mean',\n",
    "                                                      fastemit_lambda=0)\n",
    "\n",
    "warp_transducer_cuda_loss = warprnnt_pytorch.rnnt_loss(warp_transducer_logits_cuda, \n",
    "                                                       targets_cuda, \n",
    "                                                       logit_lengths_cuda, \n",
    "                                                       target_lengths_cuda,\n",
    "                                                       blank=0,\n",
    "                                                       reduction='mean',\n",
    "                                                       fastemit_lambda=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "e813779e-03a5-4d8a-a992-c686f74c415d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "warp_transducer, cpu_loss: tensor([4.4957], grad_fn=<_RNNTBackward>), cuda_loss: tensor([4.4957], device='cuda:0', grad_fn=<_RNNTBackward>)\n"
     ]
    }
   ],
   "source": [
    "print(f'warp_transducer, cpu_loss: {warp_transducer_cpu_loss}, cuda_loss: {warp_transducer_cuda_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "9f2e7765-f0e1-41a5-b829-7a824a24b95b",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(warp_transducer_cpu_loss.detach().numpy(), warp_transducer_cuda_loss.cpu().detach().numpy(), rtol=1e-7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "52b67c01-51a9-4cdc-a400-c87d1db32ce9",
   "metadata": {},
   "outputs": [],
   "source": [
    "warp_transducer_cpu_loss.backward()\n",
    "warp_transducer_cuda_loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "97b04245-cc40-4e41-8ccc-df8eb70d07a9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],\n",
      "          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],\n",
      "          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],\n",
      "\n",
      "         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],\n",
      "          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],\n",
      "          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]])\n",
      "torch.Size([1, 2, 3, 5])\n",
      "cpu\n"
     ]
    }
   ],
   "source": [
    "print(warp_transducer_logits.grad)\n",
    "print(warp_transducer_logits.grad.shape)\n",
    "print(warp_transducer_logits.grad.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "ab1906a6-903b-4339-96b2-10676c9cd55a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],\n",
      "          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],\n",
      "          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],\n",
      "\n",
      "         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],\n",
      "          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],\n",
      "          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]], device='cuda:0')\n",
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "print(warp_transducer_logits_cuda.grad)\n",
    "print(warp_transducer_logits_cuda.grad.device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "7d30b578-b87b-42cb-a5cf-025e3755be21",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(warp_transducer_logits.grad.numpy(), warp_transducer_logits_cuda.grad.cpu().numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e4a2d1bd-6769-46a9-a8dc-2b18b197639d",
   "metadata": {},
   "source": [
    "# torchaudio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "8fc126d5-70eb-4d3f-8230-75d644fe16b8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Requirement already satisfied: pip in /workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages (22.3.1)\n",
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: torchaudio==0.11.0 in /workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages (0.11.0)\n",
      "Requirement already satisfied: torch==1.11.0 in /workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages (from torchaudio==0.11.0) (1.11.0)\n",
      "Requirement already satisfied: typing-extensions in /workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages (from torch==1.11.0->torchaudio==0.11.0) (4.4.0)\n"
     ]
    }
   ],
   "source": [
    "!pip install --upgrade pip \n",
    "# !pip uninstall torchaudio -y\n",
    "!TMPDIR=$PWD pip install torchaudio==0.11.0 -i https://pypi.tuna.tsinghua.edu.cn/simple"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "15ed37df-1912-4201-b770-75e7895b755c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchaudio"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "a71d98b1-ea2a-4248-852b-649986ea0883",
   "metadata": {},
   "outputs": [],
   "source": [
    "torchaudio_cpu_loss = torchaudio.functional.rnnt_loss(torchaudio_logits,\n",
    "                                                      targets,\n",
    "                                                      logit_lengths,\n",
    "                                                      target_lengths,\n",
    "                                                      blank=0,\n",
    "                                                      reduction='mean'\n",
    "                                                      )\n",
    "\n",
    "torchaudio_cuda_loss = torchaudio.functional.rnnt_loss(torchaudio_logits_cuda,\n",
    "                                                       targets_cuda,\n",
    "                                                       logit_lengths_cuda,\n",
    "                                                       target_lengths_cuda,\n",
    "                                                       blank=0,\n",
    "                                                       reduction='mean'\n",
    "                                                       )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "ce39d7f1-f324-4018-9534-dd9eaffacb07",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torchaudio, cpu_loss: 4.495666980743408, cuda_loss: 4.49566650390625\n"
     ]
    }
   ],
   "source": [
    "print(f'torchaudio, cpu_loss: {torchaudio_cpu_loss}, cuda_loss: {torchaudio_cuda_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "59c01a88-46fd-4fef-a980-127c961029ec",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(torchaudio_cpu_loss.detach().numpy(), torchaudio_cuda_loss.cpu().detach().numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "d921805d-64d4-43b3-b744-7cef8d25a5ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "torchaudio_cpu_loss.backward()\n",
    "torchaudio_cuda_loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "353c9001-46e0-4d56-890d-8a1f517ff8fc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],\n",
      "          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],\n",
      "          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],\n",
      "\n",
      "         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],\n",
      "          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],\n",
      "          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]])\n"
     ]
    }
   ],
   "source": [
    "print(torchaudio_logits.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "2157ac5e-a0a3-4fde-9606-e5646907ea61",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],\n",
      "          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],\n",
      "          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],\n",
      "\n",
      "         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],\n",
      "          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],\n",
      "          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "print(torchaudio_logits_cuda.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "0b729bec-94a9-4148-9782-2ae109e16f23",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(torchaudio_logits.grad.numpy(), torchaudio_logits_cuda.grad.cpu().numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0ad21183-5f8a-48f1-8a1a-dce2625cad12",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e0328787-3e16-4016-a517-1f7028fac1ec",
   "metadata": {},
   "source": [
    "# check warp transducer with torchaudio "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "4f46fc30-8973-4568-952c-b336695fc89d",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(warp_transducer_cpu_loss.detach().numpy(), torchaudio_cuda_loss.cpu().detach().numpy(), rtol=1e-7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b4e1219b-ddcc-4954-9dcd-1502e3f7fe88",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(warp_transducer_logits.grad.numpy(), torchaudio_logits.grad.numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "a3053ed3-25a5-44ac-adbc-7bf75a1339eb",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(warp_transducer_logits_cuda.grad.cpu().numpy(), torchaudio_logits_cuda.grad.cpu().numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bad5757d-aec3-4eb1-ad9e-069159907406",
   "metadata": {
    "tags": []
   },
   "source": [
    "# optimized_transducer\n",
    "# https://github.com/csukuangfj/optimized_transducer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "c6cf867c-da40-4a82-8fc4-93700179b461",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Found existing installation: optimized-transducer 1.4\n",
      "Uninstalling optimized-transducer-1.4:\n",
      "  Successfully uninstalled optimized-transducer-1.4\n"
     ]
    }
   ],
   "source": [
    "!pip uninstall optimized_transducer -y"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "e6a49762-be49-4e9f-8262-f82fcb2d5c05",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using pip 22.3.1 from /workspace/DeepSpeech-2.x/speechx/venv/lib/python3.7/site-packages/pip (python 3.7)\n",
      "Collecting optimized_transducer\n",
      "  Using cached optimized_transducer-1.4-cp37-cp37m-linux_x86_64.whl\n",
      "Installing collected packages: optimized_transducer\n",
      "Successfully installed optimized_transducer-1.4\n"
     ]
    }
   ],
   "source": [
    "!TMPDIR=$PWD OT_CMAKE_ARGS=\"-DOT_COMPUTE_ARCHS=37 -DCMAKE_BUILD_TYPE=Release\" OT_MAKE_ARGS=\"-j\" pip install --verbose optimized_transducer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e588e8d9-cc97-4934-af00-1c728dd3d68a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import optimized_transducer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "d888e3ca-3ef0-46dc-a02b-23c6c2323a4c",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_transducer_cpu_loss = optimized_transducer.transducer_loss(optimized_transducer_logits.reshape(-1, logits.size(-1)),\n",
    "                                                                    targets,\n",
    "                                                                    logit_lengths,\n",
    "                                                                    target_lengths,\n",
    "                                                                    blank=0,\n",
    "                                                                    from_log_softmax=False,\n",
    "                                                                    reduction='mean'\n",
    "                                                                    )\n",
    "optimized_transducer_cuda_loss = optimized_transducer.transducer_loss(optimized_transducer_logits_cuda.reshape(-1, logits.size(-1)),\n",
    "                                                                     targets_cuda,\n",
    "                                                                     logit_lengths_cuda,\n",
    "                                                                     target_lengths_cuda,\n",
    "                                                                     blank=0,\n",
    "                                                                     from_log_softmax=False,\n",
    "                                                                     reduction='mean'\n",
    "                                                                     )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e3eaa117-2244-4752-84bd-907961769913",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "optimized_transducer, cpu_loss: 4.495666980743408, cuda_loss: 4.495666980743408\n"
     ]
    }
   ],
   "source": [
    "print(f'optimized_transducer, cpu_loss: {optimized_transducer_cpu_loss}, cuda_loss: {optimized_transducer_cuda_loss}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "1edad365-3d63-4744-91d0-3f7afb0b5a6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(optimized_transducer_cpu_loss.detach().numpy(), optimized_transducer_cuda_loss.cpu().detach().numpy(), rtol=1e-7)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "c64ad044-b741-4aef-821f-2a0283a87fb8",
   "metadata": {},
   "outputs": [],
   "source": [
    "optimized_transducer_cpu_loss.backward()\n",
    "optimized_transducer_cuda_loss.backward()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "3da7ca47-be81-440d-b954-e535d107db35",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],\n",
      "          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],\n",
      "          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],\n",
      "\n",
      "         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],\n",
      "          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],\n",
      "          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]])\n",
      "tensor([[[[-0.1312, -0.3999,  0.1770,  0.1770,  0.1770],\n",
      "          [-0.1857,  0.1225, -0.1817,  0.1225,  0.1225],\n",
      "          [-0.3209,  0.0627,  0.0693,  0.1262,  0.0627]],\n",
      "\n",
      "         [[ 0.0546, -0.2182,  0.0546,  0.0546,  0.0546],\n",
      "          [ 0.1207,  0.1207, -0.4830,  0.1207,  0.1207],\n",
      "          [-0.6926,  0.1687,  0.1865,  0.1687,  0.1687]]]], device='cuda:0')\n"
     ]
    }
   ],
   "source": [
    "print(optimized_transducer_logits.grad)\n",
    "print(optimized_transducer_logits_cuda.grad)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "2df3f240-048e-4aa4-866b-bd6c1eb4965b",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(optimized_transducer_logits.grad.numpy(), optimized_transducer_logits_cuda.grad.cpu().numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "006cc844-13e0-4357-876c-fd4a9e2161c0",
   "metadata": {},
   "outputs": [],
   "source": [
    "assert torch.allclose(torchaudio_logits.grad, optimized_transducer_logits.grad)\n",
    "assert torch.allclose(torchaudio_logits_cuda.grad, optimized_transducer_logits_cuda.grad)\n",
    "assert torch.allclose(torchaudio_logits.grad, optimized_transducer_logits_cuda.grad.cpu())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "39c4d38b-6a06-4d88-963c-0d5bc4400b5f",
   "metadata": {},
   "source": [
    "# paddle rnnt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "6bfa568e-1463-4040-9301-0c9eabb64b72",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "warp_transducer, cpu_loss: Tensor(shape=[1], dtype=float32, place=Place(cpu), stop_gradient=False,\n",
      "       [4.49566650])\n",
      "Tensor(shape=[1, 2, 3, 5], dtype=float32, place=Place(cpu), stop_gradient=False,\n",
      "       [[[[-0.13116686, -0.39992690,  0.17703126,  0.17703126,  0.17703126],\n",
      "          [-0.18572757,  0.12247056, -0.18168412,  0.12247056,  0.12247056],\n",
      "          [-0.32091251,  0.06269141,  0.06928471,  0.12624499,  0.06269141]],\n",
      "\n",
      "         [[ 0.05456069, -0.21824276,  0.05456069,  0.05456069,  0.05456069],\n",
      "          [ 0.12073959,  0.12073959, -0.48295838,  0.12073959,  0.12073959],\n",
      "          [-0.69258857,  0.16871126,  0.18645476,  0.16871126,  0.16871126]]]])\n"
     ]
    }
   ],
   "source": [
    "import paddle\n",
    "paddle.set_device('cpu')\n",
    "\n",
    "logits = paddle.to_tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.8, 0.1]],\n",
    "                              [[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.1, 0.1],\n",
    "                              [0.7, 0.1, 0.2, 0.1, 0.1]]]], dtype=paddle.float32, stop_gradient=False)\n",
    "\n",
    "targets = paddle.to_tensor([[1, 2]], dtype=paddle.int32)\n",
    "logit_lengths = paddle.to_tensor([2], dtype=paddle.int32)\n",
    "target_lengths = paddle.to_tensor([2], dtype=paddle.int32)\n",
    "\n",
    "# cpu not do log_softmax by defualt\n",
    "logprob = paddle.nn.functional.log_softmax(logits)\n",
    "rnnt_loss_cpu = paddle.nn.functional.rnnt_loss(logprob, \n",
    "                                                targets, \n",
    "                                                logit_lengths, \n",
    "                                                target_lengths,\n",
    "                                                blank=0,\n",
    "                                                reduction='mean',\n",
    "                                                fastemit_lambda=0)\n",
    "\n",
    "print(f'warp_transducer, cpu_loss: {rnnt_loss_cpu}')\n",
    "rnnt_loss_cpu.backward()\n",
    "logits_grad_cpu = logits.grad.clone()\n",
    "print(logits_grad_cpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c1fc0115-d246-4d22-b371-ba293a6e6962",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "warp_transducer, cuda_loss: Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [4.49566650])\n",
      "Tensor(shape=[1, 2, 3, 5], dtype=float32, place=Place(gpu:0), stop_gradient=False,\n",
      "       [[[[-0.13116689, -0.39992687,  0.17703123,  0.17703123,  0.17703123],\n",
      "          [-0.18572746,  0.12247057, -0.18168400,  0.12247057,  0.12247057],\n",
      "          [-0.32091236,  0.06269139,  0.06928468,  0.12624492,  0.06269139]],\n",
      "\n",
      "         [[ 0.05456069, -0.21824265,  0.05456069,  0.05456069,  0.05456069],\n",
      "          [ 0.12073954,  0.12073954, -0.48295826,  0.12073954,  0.12073954],\n",
      "          [-0.69258809,  0.16871123,  0.18645471,  0.16871123,  0.16871123]]]])\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "W1215 04:46:47.980068 56399 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 10.2, Runtime API Version: 10.2\n",
      "W1215 04:46:47.986979 56399 gpu_resources.cc:91] device: 0, cuDNN Version: 7.6.\n"
     ]
    }
   ],
   "source": [
    "paddle.set_device('gpu')\n",
    "\n",
    "logits = paddle.to_tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.8, 0.1]],\n",
    "                              [[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.1, 0.1],\n",
    "                              [0.7, 0.1, 0.2, 0.1, 0.1]]]], dtype=paddle.float32, stop_gradient=False)\n",
    "\n",
    "targets = paddle.to_tensor([[1, 2]], dtype=paddle.int32)\n",
    "logit_lengths = paddle.to_tensor([2], dtype=paddle.int32)\n",
    "target_lengths = paddle.to_tensor([2], dtype=paddle.int32)\n",
    "\n",
    "# gpu do log_softmax by default\n",
    "rnnt_loss_gpu = paddle.nn.functional.rnnt_loss(logits, \n",
    "                                               targets, \n",
    "                                               logit_lengths, \n",
    "                                                target_lengths,\n",
    "                                                blank=0,\n",
    "                                                reduction='mean',\n",
    "                                                fastemit_lambda=0)\n",
    "print(f'warp_transducer, cuda_loss: {rnnt_loss_gpu}')\n",
    "rnnt_loss_gpu.backward()\n",
    "logits_grad_gpu = logits.grad.clone()\n",
    "print(logits_grad_gpu)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "e8acd761-7b9a-48f9-a4e8-a74f26d41b2b",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(rnnt_loss_cpu.item(), rnnt_loss_gpu.item())\n",
    "np.testing.assert_allclose(rnnt_loss_gpu.item(), warp_transducer_cuda_loss.cpu().detach().numpy())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "c6957700-e87b-48e2-931a-715e7884eefe",
   "metadata": {},
   "outputs": [],
   "source": [
    "np.testing.assert_allclose(logits_grad_gpu.numpy(), logits_grad_cpu.numpy(), rtol=1e-6)\n",
    "np.testing.assert_allclose(warp_transducer_logits.grad.numpy(), logits_grad_gpu.numpy(), rtol=1e-6)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "462792a2-e656-4ccd-be4e-da7df473968d",
   "metadata": {},
   "source": [
    "# time cost"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cf8bc19e-8f11-4b93-93c0-c55b96362bcb",
   "metadata": {},
   "source": [
    "## warp_transducer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "1aa5427c-bfff-4e1a-87dd-df51cafda017",
   "metadata": {},
   "outputs": [],
   "source": [
    "logits = torch.FloatTensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.8, 0.1]],\n",
    "                              [[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.1, 0.1],\n",
    "                              [0.7, 0.1, 0.2, 0.1, 0.1]]]])\n",
    "warp_transducer_logits = logits.clone()\n",
    "torchaudio_logits = logits.clone()\n",
    "optimized_transducer_logits = logits.clone()\n",
    "\n",
    "logits_cuda = logits.to(device)\n",
    "warp_transducer_logits_cuda = logits_cuda.clone()\n",
    "torchaudio_logits_cuda = logits_cuda.clone()\n",
    "optimized_transducer_logits_cuda = logits_cuda.clone()\n",
    "\n",
    "targets = torch.tensor([[1, 2]], dtype=torch.int32)\n",
    "logit_lengths = torch.tensor([2], dtype=torch.int32)\n",
    "target_lengths = torch.tensor([2], dtype=torch.int32)\n",
    "\n",
    "targets_cuda = targets.to(device)\n",
    "logit_lengths_cuda = logit_lengths.to(device)\n",
    "target_lengths_cuda = target_lengths.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "29393009-f395-4254-9cc7-7243b46c7e77",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "260 µs ± 567 ns per loop (mean ± std. dev. of 7 runs, 1000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "warp_transducer_cuda_loss = warprnnt_pytorch.rnnt_loss(warp_transducer_logits_cuda, \n",
    "                                                       targets_cuda, \n",
    "                                                       logit_lengths_cuda, \n",
    "                                                       target_lengths_cuda,\n",
    "                                                       blank=0,\n",
    "                                                       reduction='mean',\n",
    "                                                       fastemit_lambda=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "a2e877e8-f5ae-49e5-9076-1702e84a5fe2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "109 µs ± 5.39 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "warp_transducer_cpu_loss = warprnnt_pytorch.rnnt_loss(warp_transducer_logits, \n",
    "                                                      targets, \n",
    "                                                      logit_lengths, \n",
    "                                                      target_lengths,\n",
    "                                                      blank=0,\n",
    "                                                      reduction='mean',\n",
    "                                                      fastemit_lambda=0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8e285af-c577-4691-a491-66e086a52756",
   "metadata": {},
   "source": [
    "## paddle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "58409d03-02aa-4e1a-9546-1e744e172bfd",
   "metadata": {},
   "outputs": [],
   "source": [
    "paddle.set_device('gpu')\n",
    "\n",
    "logits = paddle.to_tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.8, 0.1]],\n",
    "                              [[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.1, 0.1],\n",
    "                              [0.7, 0.1, 0.2, 0.1, 0.1]]]], dtype=paddle.float32, stop_gradient=False)\n",
    "\n",
    "targets = paddle.to_tensor([[1, 2]], dtype=paddle.int32)\n",
    "logit_lengths = paddle.to_tensor([2], dtype=paddle.int32)\n",
    "target_lengths = paddle.to_tensor([2], dtype=paddle.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "71037a95-e590-4d88-8430-ca3049e6ccda",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "133 µs ± 304 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "# gpu do log_softmax by default\n",
    "rnnt_loss_gpu = paddle.nn.functional.rnnt_loss(logits, \n",
    "                                               targets, \n",
    "                                               logit_lengths, \n",
    "                                                target_lengths,\n",
    "                                                blank=0,\n",
    "                                                reduction='mean',\n",
    "                                                fastemit_lambda=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "6a17ad3b-2f5e-45e9-8561-31b9fced9ee6",
   "metadata": {},
   "outputs": [],
   "source": [
    "paddle.set_device('cpu')\n",
    "\n",
    "logits = paddle.to_tensor([[[[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.6, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.8, 0.1]],\n",
    "                              [[0.1, 0.6, 0.1, 0.1, 0.1],\n",
    "                              [0.1, 0.1, 0.2, 0.1, 0.1],\n",
    "                              [0.7, 0.1, 0.2, 0.1, 0.1]]]], dtype=paddle.float32, stop_gradient=False)\n",
    "\n",
    "targets = paddle.to_tensor([[1, 2]], dtype=paddle.int32)\n",
    "logit_lengths = paddle.to_tensor([2], dtype=paddle.int32)\n",
    "target_lengths = paddle.to_tensor([2], dtype=paddle.int32)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "42e0930b-3796-4c47-b45f-22bc7a432e55",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "39.8 µs ± 375 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
     ]
    }
   ],
   "source": [
    "%%timeit\n",
    "logprob = paddle.nn.functional.log_softmax(logits)\n",
    "rnnt_loss_cpu = paddle.nn.functional.rnnt_loss(logprob, \n",
    "                                               targets, \n",
    "                                               logit_lengths, \n",
    "                                                target_lengths,\n",
    "                                                blank=0,\n",
    "                                                reduction='mean',\n",
    "                                                fastemit_lambda=0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "244c0df0-33d1-47d6-acbb-278187c3665a",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
